{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Tutorial 1: Dataset"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "In this tutorial, we show how the dataset-related functionality works in **EasyTPP**.\n",
    "\n",
    "\n",
    "First, we install the package."
   ],
   "metadata": {
    "id": "26Wvh9rZbTcg",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "U-gIiMZqMPFy",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "!pip install easy_tpp"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Currently, there are two options for loading the preprocessed datasets:\n",
    "- copy the pickle files from [Google Drive](https://drive.google.com/drive/folders/1f8k82-NL6KFKuNMsUwozmbzDSFycYvz7).\n",
    "- load the json files from [HuggingFace](https://huggingface.co/easytpp).\n",
    "\n",
    "The first option will be deprecated in the future; the second is recommended."
   ],
   "metadata": {
    "id": "I5YUvAc7bngQ",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "\n",
    "## Load pickle data files\n",
    "\n",
    "If we use the pickle files as the source, we download the data files, put them under a data folder, specify their paths in the config file, and run the training and prediction pipeline.\n",
    "\n",
    "\n",
    "Taking the taxi dataset as an example, the config looks like this:\n",
    "\n",
    "```\n",
    "data:\n",
    "  taxi:\n",
    "    data_format: pickle\n",
    "    train_dir:  ./data/taxi/train.pkl\n",
    "    valid_dir:  ./data/taxi/dev.pkl\n",
    "    test_dir:  ./data/taxi/test.pkl\n",
    "```\n",
    "\n",
    "See [experiment_config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml) for the full example.\n",
    "\n"
   ],
   "metadata": {
    "id": "6zfSHKhDmFSS",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Load json data files\n",
    "\n",
    "\n",
    "The recommended way is to load the data from HuggingFace, where all datasets have been preprocessed into json format and are hosted in the [EasyTPP repo](https://huggingface.co/easytpp)."
   ],
   "metadata": {
    "id": "HHPDzqud2wJf",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We use the HuggingFace `datasets` API to download and inspect the dataset directly."
   ],
   "metadata": {
    "id": "6HJd1lZB33mP",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8sM6riIxQClw",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# We choose the taxi dataset because it is relatively small.\n",
    "dataset = load_dataset('easytpp/taxi')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BZYUTFsDRHmL",
    "outputId": "478e4afb-6806-4266-83da-2f3c55bf93db",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['time_since_last_event', 'type_event', 'time_since_start', 'dim_process', 'seq_len', 'seq_idx'],\n",
       "        num_rows: 1400\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['time_since_last_event', 'type_event', 'time_since_start', 'dim_process', 'seq_len', 'seq_idx'],\n",
       "        num_rows: 200\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['time_since_last_event', 'type_event', 'time_since_start', 'dim_process', 'seq_len', 'seq_idx'],\n",
       "        num_rows: 400\n",
       "    })\n",
       "})"
      ]
     },
     "metadata": {},
     "execution_count": 3
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NJKP0ATnv4_l",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "dataset['train']['type_event'][0]"
   ]
  },
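  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The feature names printed above hint at how each record is laid out. The following is a minimal sketch, assuming that `time_since_last_event` holds the gap between consecutive `time_since_start` values and that `seq_len` counts the events in a sequence. Both relationships are assumptions rather than something this tutorial confirms, and the toy record uses made-up values, not rows from the taxi dataset.\n",
    "\n",
    "```python\n",
    "# Toy record using some of the feature names listed above; all values are made up.\n",
    "record = {\n",
    "    'time_since_start': [0.0, 0.5, 1.25, 3.0],\n",
    "    'time_since_last_event': [0.0, 0.5, 0.75, 1.75],\n",
    "    'type_event': [2, 0, 1, 0],\n",
    "    'seq_len': 4,\n",
    "}\n",
    "\n",
    "\n",
    "def check_record(rec, tol=1e-9):\n",
    "    # Hypothetical helper: True if the record satisfies the assumed relationships.\n",
    "    times = rec['time_since_start']\n",
    "    gaps = rec['time_since_last_event']\n",
    "    # Assumption: every per-event list has exactly seq_len entries.\n",
    "    same_len = len(times) == len(gaps) == len(rec['type_event']) == rec['seq_len']\n",
    "    # Assumption: each gap is the difference between consecutive timestamps.\n",
    "    gaps_ok = all(\n",
    "        abs((times[i] - times[i - 1]) - gaps[i]) <= tol\n",
    "        for i in range(1, len(times))\n",
    "    )\n",
    "    return same_len and gaps_ok\n",
    "\n",
    "\n",
    "print(check_record(record))\n",
    "```\n",
    "\n",
    "The same check could be run on a real row, e.g. `dataset['train'][0]`, to see whether these assumptions hold for the downloaded data."
   ]
  },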
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "an__K1qzmRSo",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "To use this loading path in the training/evaluation pipeline, we similarly put the HuggingFace repo name in the config file, e.g.,\n",
    "\n",
    "```\n",
    "data:\n",
    "  taxi:\n",
    "    data_format: json\n",
    "    train_dir:  easytpp/taxi\n",
    "    valid_dir:  easytpp/taxi\n",
    "    test_dir:  easytpp/taxi\n",
    "```\n",
    "\n",
    "Note that we can also point the config at local json files:\n",
    "\n",
    "```\n",
    "data:\n",
    "  taxi:\n",
    "    data_format: json\n",
    "    train_dir:  ./data/taxi/train.json\n",
    "    valid_dir:  ./data/taxi/dev.json\n",
    "    test_dir:  ./data/taxi/test.json\n",
    "```"
   ]
  },
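  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both configs above use `data_format: json`, but one points at a HuggingFace repo and the other at local files. The following is a minimal sketch of how a loader could tell the two apart, assuming that a `.json` suffix marks a local file and that anything else is a repo name. The helper `resolve_json_source` is hypothetical and is not EasyTPP's actual implementation; it only builds the arguments one might pass to `datasets.load_dataset`.\n",
    "\n",
    "```python\n",
    "def resolve_json_source(source_dir, split):\n",
    "    # Hypothetical helper: returns (path, kwargs) for datasets.load_dataset.\n",
    "    if source_dir.endswith('.json'):\n",
    "        # Local file: use the generic 'json' builder with an explicit file.\n",
    "        # When data_files is a single path, the rows land in the 'train' split.\n",
    "        return 'json', {'data_files': source_dir, 'split': 'train'}\n",
    "    # Otherwise treat the value as a repo name on the HuggingFace Hub.\n",
    "    return source_dir, {'split': split}\n",
    "\n",
    "\n",
    "path, kwargs = resolve_json_source('easytpp/taxi', 'validation')\n",
    "print(path, kwargs)\n",
    "```\n",
    "\n",
    "Calling `load_dataset(path, **kwargs)` with the returned values would then fetch the requested split; that call needs the `datasets` package and, for a repo name, network access."
   ]
  },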
  {
   "cell_type": "code",
   "source": [],
   "metadata": {
    "id": "O2EvMK0x6aKY",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}